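% Implement the generalized Bierens max test based on the particle swarm optimization (PSO) algorithm.
% 
% Arguments:
% datau : n*1 vector of generalized residuals
% dataw : n*p matrix of covariates
% bnd : a constant defining the parameter space [-bnd,+bnd]^p  
% lambdaset : J possible values of lambda
% eps_constant : a small constant used to determine which coefficients are zero
% pso_nmax : maximum number of PSO runs per lambda
% Optional name-value pairs: 'swarmsize' (default 200), 'stalliter' (default 200),
%   'tol' (default 5e-2), 'dm' (default 0, indicator for exponential weight demeaning)
% 
% Outputs (entry j corresponds to lambdaset(j)):
% test_stat_all  : J*1 vector of the test statistics
% gamma_hat_all  : J*p matrix of the optimized coefficients
% gamma_hat_zero_all : J*1 vector of the number of zeros in the optimized coefficients
% rtime : 1*J vector of elapsed time (seconds) spent on each lambda
% 
% Examples:
% [test, gammahat, nzero] = pen_Bierens_test(u,w,10,[1,0.5],1e-6,100)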

function [test_stat_all,gamma_hat_all,gamma_hat_zero_all,rtime]...
    = pen_Bierens_test(datau,dataw,bnd,lambdaset,eps_constant,pso_nmax,varargin)
% PEN_BIERENS_TEST Penalized generalized Bierens max test solved with particleswarm.
%
% For every penalty level lambda, hdm(gamma, datau, w, lambda, 'dm', dm) is
% minimized over gamma in [-bnd, bnd]^p with particleswarm, where
% w = atan(zscore(dataw)), and the test statistic is -fval.
% (hdm is defined elsewhere; presumably it returns the negated penalized
% statistic -- confirm against hdm.m.)
%
% The lambdas are visited from largest to smallest. Each statistic must be no
% smaller than the one obtained for the previous (larger) lambda, up to a
% slack of 1e-2; particleswarm is rerun up to pso_nmax times until that holds.
% If no run satisfies it, the previous lambda's statistic and coefficients
% are carried forward.
%
% Inputs:
%   datau        : n*1 vector of generalized residuals
%   dataw        : n*p matrix of covariates
%   bnd          : bound of the parameter space [-bnd,+bnd]^p
%   lambdaset    : J penalty levels, in any order
%   eps_constant : coefficients with |gamma| < eps_constant count as zero
%   pso_nmax     : maximum number of particleswarm runs per lambda (>= 1)
%   Name-value options:
%     'swarmsize' (200)  particleswarm SwarmSize
%     'stalliter' (200)  particleswarm MaxStallIterations
%     'tol'       (5e-2) particleswarm FunctionTolerance
%     'dm'        (0)    exponential-weight demeaning indicator, passed to hdm
%
% Outputs (entry j always corresponds to lambdaset(j)):
%   test_stat_all      : J*1 test statistics
%   gamma_hat_all      : J*p optimized coefficients
%   gamma_hat_zero_all : J*1 number of coefficients with |gamma| < eps_constant
%   rtime              : 1*J elapsed seconds spent on each lambda

ip = inputParser;
addParameter(ip, 'swarmsize', 200);
addParameter(ip, 'stalliter', 200);  % last 'stalliter' iterations are examined to determine when to stop
addParameter(ip, 'tol', 5e-2);       % tolerance level used for stopping rule
addParameter(ip, 'dm', 0);           % indicator for exponential weight demeaning
parse(ip, varargin{:});
swarmsize = ip.Results.swarmsize;
stalliter = ip.Results.stalliter;
tol       = ip.Results.tol;
dm        = ip.Results.dm;

% Without at least one run there is no particleswarm result to report.
if ~isscalar(pso_nmax) || ~(pso_nmax >= 1)
  error('pen_Bierens_test:pso_nmax', 'pso_nmax must be a scalar >= 1.');
end

mono_slack = 1e-2;   % allowed decrease of the statistic between consecutive lambdas

% Visit lambdas largest-first; 'order' maps each sorted position back to the
% caller's position in lambdaset so the outputs stay aligned with it.
[lambdasort, order] = sort(lambdaset(:), 'descend');
J = numel(lambdasort);
p = size(dataw,2);

% None of the following depends on lambda, so compute it once.
w  = atan(zscore(dataw));
mu = datau;
lb = -bnd*ones(1,p);
ub =  bnd*ones(1,p);
options = optimoptions('particleswarm','Display','off','SwarmSize',swarmsize, ...
    'MaxStallIterations',stalliter,'FunctionTolerance',tol);

test_stat_all      = zeros(J,1);
gamma_hat_all      = zeros(J,p);
gamma_hat_zero_all = zeros(J,1);
rtime              = zeros(1,J);

% Solution for the previously visited (larger) lambda. The initial statistic
% is a sentinel that any finite first statistic passes.
prev_stat  = -1e10;
prev_gamma = zeros(1,p);

for k = 1:J
  lambda = lambdasort(k);
  ftn_gm = @(gm) hdm(gm,mu,w,lambda,'dm',dm);

  Tstart   = tic;
  accepted = false;
  for i_run = 1:pso_nmax
    [x,fval] = particleswarm(ftn_gm,p,lb,ub,options);
    % Accept once the statistic is non-decreasing as lambda decreases.
    if (-fval) > prev_stat - mono_slack
      accepted = true;
      break
    end
  end

  if accepted || k == 1
    % For the first lambda there is no previous solution to fall back on,
    % so the last run's result is used even if it was not accepted.
    test_stat = -fval;
    gamma_hat = x;
  else
    % Every run fell short: carry the previous lambda's solution forward so
    % the reported statistic and coefficients belong together.
    test_stat = prev_stat;
    gamma_hat = prev_gamma;
  end

  idx = order(k);
  rtime(idx)              = toc(Tstart);
  test_stat_all(idx)      = test_stat;
  gamma_hat_all(idx,:)    = gamma_hat;
  gamma_hat_zero_all(idx) = sum(abs(gamma_hat) < eps_constant);

  prev_stat  = test_stat;
  prev_gamma = gamma_hat;
end
end